

import time 
import logging
import os
import random
import torch
import torch.utils.data
from . import base 
from tqdm import tqdm
import numpy as np
import copy
import pytorch3d.transforms as pytorch3d

import pandas as pd 
import csv

class PCloader(base.Dataset):
    """Dataset that reads one point-cloud sample per file with ``np.load``.

    Each file is indexed by string key and must provide the arrays
    ``camera_partial_pc``, ``base_pose``, ``joint_xyz``, ``joint_rpy``,
    ``cls``, ``full_seg``, ``atc`` and ``point_cloud``.
    NOTE(review): the files are presumably ``.npz`` archives -- confirm.

    Samples are read from disk on demand in ``__getitem__``; nothing is
    preloaded by the constructor.
    """

    def __init__(
        self,
        data_source,
        split_file,
        pc_size=1024,
        return_filename=False
    ):
        """Resolve the list of sample files for this split.

        Args:
            data_source: forwarded to ``get_instance_filenames``.
            split_file: forwarded to ``get_instance_filenames``.
            pc_size: number of points drawn per cloud when
                ``get_all_files`` builds its cache through ``sample_pc``.
            return_filename: stored on the instance; not consulted by the
                methods of this class (``file_name`` is always returned).
        """
        self.pc_size = pc_size
        self.return_filename = return_filename
        # Resolve the file list once. ``gt_files`` and ``pc_paths`` have
        # always described the same files, so they share a single result
        # instead of calling the helper twice.
        self.pc_paths = self.get_instance_filenames(data_source, split_file)
        self.gt_files = self.pc_paths
        # Filled lazily by get_all_files().
        self.point_clouds = None

        # The message is kept as-is; the actual loading is deferred.
        print("loading {} point clouds into memory...".format(len(self.pc_paths)))

    @staticmethod
    def _read_arrays(path, keys):
        """Load ``keys`` from ``path`` and release the file handle afterwards."""
        archive = np.load(path)
        try:
            return {key: archive[key] for key in keys}
        finally:
            # ``np.load`` returns an object with ``close()`` for .npz
            # archives; a plain ndarray has none, so only close when present.
            close = getattr(archive, "close", None)
            if close is not None:
                close()

    def get_all_files(self):
        """Return ``(point_clouds, pc_paths)``.

        ``point_clouds`` is a list holding the ``sample_pc`` result
        ``(pc, atc)`` for every path, sampled with ``self.pc_size`` points.
        It is built on the first call and cached on the instance.
        """
        if getattr(self, "point_clouds", None) is None:
            self.point_clouds = [
                self.sample_pc(path, self.pc_size) for path in self.pc_paths
            ]
        return self.point_clouds, self.pc_paths

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        Besides the raw ground truth, the dict contains ``zero_mean_*``
        variants in which the mean of the first three columns of ``pts`` has
        been subtracted from the points, the pose translation and
        ``gt_xyz``. The mean itself is returned as ``pts_center``.
        """
        path = self.pc_paths[idx]
        arrays = self._read_arrays(
            path,
            (
                'camera_partial_pc',
                'base_pose',
                'joint_xyz',
                'joint_rpy',
                'cls',
                'full_seg',
                'atc',
                'point_cloud',
            ),
        )

        pts = torch.from_numpy(arrays['camera_partial_pc'])
        # presumably base_pose is a 4x4 homogeneous transform -- TODO confirm
        base_pose = arrays['base_pose']
        gt_rot = torch.from_numpy(base_pose[:3, :3])
        gt_trans = torch.from_numpy(base_pose[:3, 3])
        # Entry at index 1 along the first axis of the joint arrays.
        gt_xyz = torch.from_numpy(arrays['joint_xyz'][1])
        gt_rpy = torch.from_numpy(arrays['joint_rpy'][1])

        # The rotation is transposed before the 6D encoding.
        # NOTE(review): matrix_to_rotation_6d should keep the first two rows
        # of its input, which would make this the first two columns of
        # gt_rot -- confirm against the pytorch3d documentation.
        rot_6d = pytorch3d.matrix_to_rotation_6d(gt_rot.permute(1, 0))
        gt_pose = torch.cat([rot_6d.float(), gt_trans.float()], dim=-1)

        # Centroid of the xyz columns of the observed points.
        center = torch.mean(pts[:, :3], dim=0)

        centered_pts = pts.clone()
        centered_pts[:, :3] -= center  # broadcasts over the point axis
        centered_pose = gt_pose.clone()
        centered_pose[-3:] -= center  # only the translation part is shifted
        centered_xyz = gt_xyz.clone()
        centered_xyz -= center

        data_dict = {
            'gt_pose': gt_pose,
            'pts': pts,
            'gt_xyz': gt_xyz,
            'gt_rpy': gt_rpy,
            'seg': torch.from_numpy(arrays['cls']),
            'full_seg': torch.from_numpy(arrays['full_seg']),
            # Left as a NumPy array (first entry dropped), as before.
            'atc': arrays['atc'][1:],
            'gt_rot': gt_rot,
            'gt_trans': gt_trans,
            'zero_mean_pts': centered_pts,
            'zero_mean_gt_pose': centered_pose,
            'zero_mean_gt_xyz': centered_xyz,
            # Only copied: this value is not shifted by the centroid.
            'zero_mean_gt_rpy': gt_rpy.clone(),
            'pts_center': center,
            'file_name': path,
            'point_cloud': torch.from_numpy(arrays['point_cloud']).float(),
        }
        return data_dict

    def __len__(self):
        """Return the number of sample files in the split."""
        return len(self.gt_files)

    def sample_pc(self, f, samp=1024):
        """Return a random subset of the ``point_cloud`` stored in ``f``.

        Args:
            f: path of the sample file.
            samp: maximum number of points to keep; fewer are returned when
                the cloud has fewer points.

        Returns:
            ``(pc, atc)``: the selected points as a float tensor, and the
            entry at index 1 of ``atc`` as a float tensor.
        """
        # Read both arrays from a single open of the file.
        arrays = self._read_arrays(f, ('point_cloud', 'atc'))
        pc = torch.from_numpy(arrays['point_cloud']).float()
        atc = torch.from_numpy(arrays['atc']).float()[1]
        # Random permutation of the point indices, truncated to ``samp``.
        keep = torch.randperm(pc.shape[0])[:samp]
        return pc[keep], atc



    
